﻿using Common;
using FrontEnd;
using Linguist;
using Linguist.Acoustic.Tiedstate;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Util;

namespace Linguist.Acoustic.Tiedstate
{
    /// <summary>
    /// 
    /// This class is used for estimating a MLLR transform for each cluster of data.
    /// The clustering must be previously performed using
    /// ClusteredDensityFileData.java
    /// 
    /// @author Bogdan Petcu
    /// </summary>
    /// <summary>
    /// Estimates an MLLR transform for each cluster of data.
    /// The clustering must be performed beforehand using
    /// <see cref="ClusteredDensityFileData"/>.
    ///
    /// @author Bogdan Petcu
    /// </summary>
    public class Stats
    {
        private ClusteredDensityFileData means;
        // Legetter's G matrices: [cluster][stream][row][len + 1][len + 1].
        private double[][][][][] regLs;
        // Right-hand-side accumulators: [cluster][stream][row][len + 1].
        private double[][][][] regRs;
        private int nrOfClusters;
        private Sphinx3Loader loader;
        private float varFlor;  // variance floor
        private LogMath logMath = LogMath.getLogMath();

        public Stats(ILoader loader, ClusteredDensityFileData means)
        {
            this.loader = (Sphinx3Loader) loader;
            this.nrOfClusters = means.getNumberOfClusters();
            this.means = means;
            this.varFlor = 1e-5f;
            this.invertVariances();
            this.init();
        }

        /// <summary>
        /// Allocates the regression accumulators, one set per cluster and per
        /// feature stream. Each row carries len + 1 entries so the offset term
        /// of the affine transform can be accumulated as well.
        /// </summary>
        private void init()
        {
            this.regLs = new double[nrOfClusters][][][][];
            this.regRs = new double[nrOfClusters][][][];

            for (int i = 0; i < nrOfClusters; i++)
            {
                int numStreams = loader.getNumStreams();
                this.regLs[i] = new double[numStreams][][][];
                this.regRs[i] = new double[numStreams][][];

                for (int j = 0; j < numStreams; j++)
                {
                    int len = loader.getVectorLength()[j];
                    this.regLs[i][j] = new double[len][][];
                    this.regRs[i][j] = new double[len][];

                    // BUGFIX: the inner arrays must be allocated explicitly.
                    // The Java original used "new double[len][len + 1][len + 1]",
                    // which allocates all dimensions at once; C# jagged arrays
                    // do not. Without this loop, collect() and
                    // fillRegLowerPart() dereference null rows.
                    for (int l = 0; l < len; l++)
                    {
                        this.regLs[i][j][l] = new double[len + 1][];
                        for (int p = 0; p <= len; p++)
                        {
                            this.regLs[i][j][l][p] = new double[len + 1];
                        }
                        this.regRs[i][j][l] = new double[len + 1];
                    }
                }
            }
        }

        public ClusteredDensityFileData getClusteredData()
        {
            return this.means;
        }

        public double[][][][][] getRegLs()
        {
            return regLs;
        }

        public double[][][][] getRegRs()
        {
            return regRs;
        }

        /// <summary>
        /// Inverts the variances in place, flooring them first: non-positive
        /// variances become 0.5, values below the floor become 1 / varFlor,
        /// and everything else is replaced by its reciprocal.
        /// </summary>
        private void invertVariances()
        {
            int numStates = loader.getNumStates();
            int numGaussians = loader.getNumGaussiansPerState();
            int len = loader.getVectorLength()[0];

            for (int i = 0; i < numStates; i++)
            {
                for (int k = 0; k < numGaussians; k++)
                {
                    // Hoisted: each Gaussian's variance vector is fetched once.
                    float[] variance = loader.getVariancePool().get(
                            i * numGaussians + k);

                    for (int l = 0; l < len; l++)
                    {
                        if (variance[l] <= 0.0)
                        {
                            variance[l] = 0.5f;
                        }
                        else if (variance[l] < varFlor)
                        {
                            variance[l] = (float) (1.0 / varFlor);
                        }
                        else
                        {
                            variance[l] = (float) (1.0 / variance[l]);
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Converts log-domain component scores into linear-domain posterior
        /// values, normalized against the maximum score.
        /// NOTE: the input array is modified in place and returned.
        /// </summary>
        /// <param name="componentScores">log-domain scores per component.</param>
        /// <returns>posterior values for all components.</returns>
        private float[] computePosteriors(float[] componentScores)
        {
            float[] posteriors = componentScores;

            float max = posteriors[0];
            for (int i = 1; i < posteriors.Length; i++)
            {
                if (posteriors[i] > max)
                {
                    max = posteriors[i];
                }
            }

            for (int i = 0; i < posteriors.Length; i++)
            {
                posteriors[i] = (float) logMath.logToLinear(posteriors[i] - max);
            }

            return posteriors;
        }

        /// <summary>
        /// Directly collects counts from a decoding result and accumulates
        /// them into the regression statistics, stored separately for each
        /// cluster.
        /// </summary>
        /// <param name="result">Result object to collect counts from.</param>
        /// <exception cref="InvalidOperationException">
        /// if the result contains no best token.
        /// </exception>
        public void collect(ISpeechResult result)
        {
            IToken token = result.getResult().getBestToken();

            if (token == null)
                throw new InvalidOperationException("Best token not found!");

            int numGaussians = loader.getNumGaussiansPerState();
            int len = loader.getVectorLength()[0];

            do
            {
                FloatData feature = (FloatData) token.getData();
                ISearchState ss = token.getSearchState();

                // Only emitting HMM states contribute counts.
                if (!(ss is IHMMSearchState && ss.isEmitting()))
                {
                    token = token.getPredecessor();
                    continue;
                }

                float[] componentScore = token.calculateComponentScore(feature);
                float[] featureVector = FloatData.toFloatData(feature).getValues();
                int mId = (int) ((IHMMSearchState) token.getSearchState())
                        .getHMMState().getMixtureId();
                float[] posteriors = this.computePosteriors(componentScore);

                for (int i = 0; i < componentScore.Length; i++)
                {
                    int cluster = means.getClassIndex(mId * numGaussians + i);
                    float dnom = posteriors[i];

                    if (dnom > 0.0)
                    {
                        // Hoisted: mean and variance vectors fetched once per
                        // Gaussian instead of once per feature dimension.
                        float[] tmean = loader.getMeansPool().get(
                                mId * numGaussians + i);
                        float[] variance = loader.getVariancePool().get(
                                mId * numGaussians + i);

                        for (int j = 0; j < featureVector.Length; j++)
                        {
                            float mean = posteriors[i] * featureVector[j];
                            float wtMeanVar = mean * variance[j];
                            float wtDcountVar = dnom * variance[j];

                            for (int p = 0; p < featureVector.Length; p++)
                            {
                                float wtDcountVarMean = wtDcountVar * tmean[p];

                                // Only the upper triangle of G is accumulated;
                                // the lower part is mirrored later by
                                // fillRegLowerPart().
                                for (int q = p; q < featureVector.Length; q++)
                                {
                                    regLs[cluster][0][j][p][q] += wtDcountVarMean
                                            * tmean[q];
                                }
                                regLs[cluster][0][j][p][len] += wtDcountVarMean;
                                regRs[cluster][0][j][p] += wtMeanVar * tmean[p];
                            }
                            regLs[cluster][0][j][len][len] += wtDcountVar;
                            regRs[cluster][0][j][len] += wtMeanVar;
                        }
                    }
                }

                token = token.getPredecessor();
            } while (token != null);
        }

        /// <summary>
        /// Fills the lower part of Legetter's set of G matrices by mirroring
        /// the upper triangle accumulated in collect().
        /// </summary>
        public void fillRegLowerPart()
        {
            for (int i = 0; i < this.nrOfClusters; i++)
            {
                for (int j = 0; j < loader.getNumStreams(); j++)
                {
                    int len = loader.getVectorLength()[j];
                    for (int l = 0; l < len; l++)
                    {
                        for (int p = 0; p <= len; p++)
                        {
                            for (int q = p + 1; q <= len; q++)
                            {
                                regLs[i][j][l][q][p] = regLs[i][j][l][p][q];
                            }
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Builds a Transform from the accumulated statistics.
        /// </summary>
        /// <returns>the updated transform.</returns>
        public Transform createTransform()
        {
            Transform transform = new Transform(loader, nrOfClusters);
            transform.update(this);
            return transform;
        }

    }
}
